package pers.hll.aigc4chat.server.handler;

import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import pers.hll.aigc4chat.base.util.StringUtil;
import pers.hll.aigc4chat.base.util.XmlUtil;
import pers.hll.aigc4chat.base.xml.OllamaConfig;
import pers.hll.aigc4chat.model.ollama.OllamaApi;
import pers.hll.aigc4chat.model.ollama.constant.Role;
import pers.hll.aigc4chat.model.ollama.request.body.ChatReqBody;
import pers.hll.aigc4chat.model.ollama.request.body.Message;
import pers.hll.aigc4chat.model.ollama.response.body.ChatRespBody;
import pers.hll.aigc4chat.protocol.wechat.response.webwxsync.AddMsg;
import pers.hll.aigc4chat.server.entity.WeChatMessage;
import pers.hll.aigc4chat.server.service.IWeChatApiService;
import pers.hll.aigc4chat.server.service.IWeChatMessageService;

import java.util.List;
import java.util.function.Function;

/**
 * Message handler that replies to incoming WeChat messages with text generated by an Ollama model.
 *
 * <p>For every handled message the handler fetches the chat history for the sender through
 * {@link IWeChatMessageService}, converts it into Ollama chat messages, requests a non-streaming
 * chat completion via {@link OllamaApi} and sends the generated text back to the sender.
 *
 * @author hll
 * @since 2024/04/21
 */
@RequiredArgsConstructor
@Component(MessageHandlerName.OLLAMA_MESSAGE_HANDLER)
public class OllamaMessageHandler implements MessageHandler {

    /** Instruction template appended to the last message; {@code {}} is replaced by the original content. */
    private static final String PROMPT_TEMPLATE = "{}。请用中文回答，不要暴露自己的身份，你是一个人类。";

    private final IWeChatApiService weChatApiService;

    private final IWeChatMessageService weChatMessageService;

    /**
     * Generates a model reply for the sender of {@code addMsg} and sends it back as a text message.
     *
     * <p>Nothing is sent when there is no chat history to answer or when the model produced a
     * blank reply.
     *
     * @param addMsg the incoming WeChat message; its sender is both the history key and the reply target
     * @throws IllegalStateException if the chat response carries no message
     */
    @Override
    public void handle(AddMsg addMsg) {
        // TODO: messages that @-mention this account in group chats are not grouped yet; to be improved.
        String sender = addMsg.getFromUserName();
        // NOTE(review): the second argument is presumably the maximum number of history entries — confirm.
        List<WeChatMessage> history = weChatMessageService.listChatHistoryForModel(sender, 5);
        if (history == null || history.isEmpty()) {
            // An empty conversation gives the model nothing to answer, so skip the call entirely
            // instead of forwarding an empty reply to the user.
            return;
        }

        OllamaConfig ollamaConfig = XmlUtil.readXmlConfig(OllamaConfig.class);
        ChatRespBody chatRespBody = OllamaApi.chat(ChatReqBody.builder()
                .model(resolveModelName(ollamaConfig))
                .messages(addPrompt(converter(history, sender)))
                .stream(false)
                .build());

        // Fail with a descriptive error instead of an anonymous NullPointerException.
        if (chatRespBody == null || chatRespBody.getMessage() == null) {
            throw new IllegalStateException(
                    "Ollama chat returned no message; cannot reply to user '" + sender + "'");
        }

        String reply = chatRespBody.getMessage().getContent();
        if (StringUtils.isBlank(reply)) {
            // Do not send empty text messages.
            return;
        }
        weChatApiService.sendTextMessage(reply, sender);
    }

    /**
     * Builds the model identifier from the configuration.
     *
     * @param ollamaConfig the Ollama configuration
     * @return {@code model:tag} when both are configured, {@code model} when only the model is
     *         configured, or an empty string when no model is configured
     */
    private String resolveModelName(OllamaConfig ollamaConfig) {
        String model = ollamaConfig.getModel();
        if (StringUtils.isEmpty(model)) {
            return "";
        }
        String tag = ollamaConfig.getTag();
        return StringUtils.isEmpty(tag) ? model : StringUtil.format("{}:{}", model, tag);
    }

    /**
     * Appends the instruction prompt to the last message only, so that the model does not answer
     * every history message. The last message is modified in place.
     *
     * @param messageList the message list
     * @return the same list, with the prompt applied to its last element (unchanged when empty)
     */
    private List<Message> addPrompt(List<Message> messageList) {
        if (messageList.isEmpty()) {
            return messageList;
        }
        Message last = messageList.get(messageList.size() - 1);
        last.setContent(StringUtil.format(PROMPT_TEMPLATE, last.getContent()));
        return messageList;
    }

    /**
     * Converts WeChat history messages into Ollama chat messages.
     *
     * <p>Messages whose sender equals {@code userName} get {@link Role#USER}; all others get
     * {@link Role#ASSISTANT}.
     *
     * @param weChatMessages the WeChat history messages
     * @param userName       the user name of the conversation partner
     * @return the Ollama messages, in the same order as the input
     */
    private List<Message> converter(List<WeChatMessage> weChatMessages, String userName) {
        Function<WeChatMessage, Message> toOllamaMessage = weChatMessage -> {
            boolean fromUser = StringUtils.equals(weChatMessage.getFromUserName(), userName);
            return Message.builder()
                    .content(weChatMessage.getContent())
                    .role(fromUser ? Role.USER : Role.ASSISTANT)
                    .build();
        };
        return weChatMessages.stream()
                .map(toOllamaMessage)
                .toList();
    }
}
